package com.zhangc.sqldruid.controller;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.sql.DataSource;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.http.MediaType;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.web.HttpMediaTypeNotSupportedException;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.View;
import org.springframework.web.servlet.mvc.Controller;

import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLDeleteStatement;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor;
import com.alibaba.druid.sql.parser.ParserException;
import com.alibaba.druid.util.JdbcUtils;
import com.alibaba.druid.util.Utils;
import com.alibaba.druid.wall.Violation;
import com.alibaba.druid.wall.WallConfig;
import com.alibaba.druid.wall.WallVisitor;
import com.alibaba.druid.wall.spi.MySqlWallProvider;
import com.alibaba.druid.wall.violation.SyntaxErrorViolation;
import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

/**
 * SQL statement execution gateway: only SQL that passes the Druid wall-filter check is executed.
 *
 * <p>The request body is SQL text whose {@code Content-Type} must be included by
 * {@code application/sql}; the target executor is chosen with the {@code shard} request
 * parameter. The JSON response carries {@code error} (wall-filter violations), {@code results}
 * (one map per statement) or {@code exception} (any failure while reading or executing).
 *
 * @author
 */
public class SQLSelectHttpRequestHandler implements Controller, InitializingBean {

    private static final Logger LOGGER = LoggerFactory.getLogger(SQLSelectHttpRequestHandler.class);

    /** Media type that must include the request's {@code Content-Type}. */
    private static final MediaType SQL_MEDIA_TYPE = MediaType.parseMediaType("application/sql");

    /** Maximum number of rows any single query may return. */
    private static final int DEFAULT_MAX_ROWS = 200;

    /** Shard id under which {@link #druidDataSource} is registered. */
    private static final String DEFAULT_SHARD_ID = "shard0";

    @Autowired
    @Qualifier("selfObjectMapper")
    ObjectMapper objectMapper;

    @Autowired
    public DataSource druidDataSource;

    @Autowired
    ResourceLoader resourceLoader;

    private MySqlWallProvider sqlWallProvider;

    /** Shard id -> SQL executor; populated in {@link #afterPropertiesSet()}. */
    private final Map<String, JdbcTemplate> jdbcTemplates = Maps.newConcurrentMap();

    /** Serializes the model as UTF-8 JSON with the injected {@link #objectMapper}. */
    private final View jsonView = new View() {
        @Override
        public String getContentType() {
            return "application/json; charset=utf-8";
        }

        @Override
        public void render(Map<String, ?> model, HttpServletRequest request, HttpServletResponse response)
                throws Exception {
            response.setContentType(getContentType());
            // try-with-resources closes the generator (flushing its buffer) even if serialization fails.
            try (JsonGenerator generator = objectMapper.getFactory()
                    .createGenerator(response.getOutputStream(), JsonEncoding.UTF8)) {
                objectMapper.writeValue(generator, model);
                generator.flush();
            }
        }
    };

    /**
     * Parses {@code sql} as MySQL and runs every parsed statement through the wall filter.
     *
     * @param sql        raw SQL text, possibly containing several statements
     * @param violations output list; wall violations, or a single syntax-error violation, are appended
     * @return the parsed statements, or an empty list when the SQL cannot be parsed
     */
    private List<SQLStatement> filter(String sql, List<Violation> violations) {
        try {
            List<SQLStatement> stmtList = new MySqlStatementParser(sql).parseStatementList();
            for (SQLStatement stmt : stmtList) {
                // A fresh visitor per statement; its findings are accumulated into the caller's list.
                WallVisitor wallVisitor = sqlWallProvider.createWallVisitor();
                stmt.accept(wallVisitor);
                violations.addAll(wallVisitor.getViolations());
            }
            return stmtList;
        } catch (ParserException e) {
            violations.add(new SyntaxErrorViolation(e, sql));
            return new ArrayList<SQLStatement>();
        }
    }

    /*
     * (non-Javadoc)
     * @see org.springframework.web.servlet.mvc.Controller#handleRequest(javax.servlet.http.HttpServletRequest,
     * javax.servlet.http.HttpServletResponse)
     */
    @Override
    public ModelAndView handleRequest(HttpServletRequest request, HttpServletResponse response) throws Exception {
        ModelAndView modelAndView = new ModelAndView(jsonView);
        try {
            String sql = readSql(request);
            List<Violation> violations = new ArrayList<Violation>();
            List<SQLStatement> stmts = filter(sql, violations);
            if (!violations.isEmpty()) {
                modelAndView.addObject("error", violations);
            } else {
                modelAndView.addObject("results", execute(stmts, request.getParameter("shard")));
            }
        } catch (Exception e) {
            // Boundary of the endpoint: report the failure in the JSON body instead of propagating.
            modelAndView.addObject("exception", e.getClass().getSimpleName() + ":" + e.getMessage());
            LOGGER.error("error.", e);
        }
        return modelAndView;
    }

    /**
     * Reads the request body as SQL text after validating its content type.
     *
     * @param request incoming request whose body holds the SQL
     * @return the SQL text
     * @throws HttpMediaTypeNotSupportedException if the Content-Type is missing or not {@code application/sql}
     * @throws IOException                        if the body cannot be read
     */
    private String readSql(HttpServletRequest request) throws IOException, HttpMediaTypeNotSupportedException {
        String contentType = request.getContentType();
        if (StringUtils.isBlank(contentType)) {
            throw new HttpMediaTypeNotSupportedException("missing Content-Type, expected " + SQL_MEDIA_TYPE);
        }
        MediaType mediaType = MediaType.parseMediaType(contentType);
        if (!SQL_MEDIA_TYPE.includes(mediaType)) {
            throw new HttpMediaTypeNotSupportedException(mediaType + " not supported");
        }
        // The charset parameter is optional; InputStreamReader rejects a null charset, so default to UTF-8.
        Charset charset = mediaType.getCharset() != null ? mediaType.getCharset() : StandardCharsets.UTF_8;
        try (Reader reader = new InputStreamReader(request.getInputStream(), charset)) {
            return Utils.read(reader);
        }
    }

    /**
     * Executes already-validated statements, in order, with the executor registered for {@code shardId}.
     *
     * <p>NOTE(review): no transaction is opened here, so if a statement fails, the statements before
     * it have already run; the exception propagates to {@link #handleRequest}.
     *
     * @param stmts   statements that passed the wall filter
     * @param shardId value of the {@code shard} request parameter, may be null
     * @return one result map per statement ({@code data}, {@code affected}, {@code error}, or empty)
     */
    private List<Map<String, Object>> execute(List<SQLStatement> stmts, String shardId) {
        // Guard before lookup: the concurrent map does not accept null keys.
        JdbcTemplate jdbcTemplate = StringUtils.isNotEmpty(shardId) ? jdbcTemplates.get(shardId) : null;
        List<Map<String, Object>> results = Lists.newArrayList();
        for (SQLStatement stmt : stmts) {
            Map<String, Object> result = Maps.newHashMap();
            if (jdbcTemplate == null) {
                result.put("error", "one of " + JSON.toJSONString(jdbcTemplates.keySet()));
            } else {
                String sql = toSql(stmt);
                if (stmt instanceof SQLSelectStatement) {
                    result.put("data", jdbcTemplate.queryForList(sql));
                } else if (stmt instanceof SQLDeleteStatement || stmt instanceof SQLUpdateStatement) {
                    result.put("affected", jdbcTemplate.update(sql));
                } else {
                    jdbcTemplate.execute(sql);
                }
            }
            results.add(result);
        }
        return results;
    }

    /** Renders a parsed statement back to MySQL text. */
    private static String toSql(SQLStatement stmt) {
        StringBuilder out = new StringBuilder();
        stmt.accept(new MySqlOutputVisitor(out));
        return out.toString();
    }

    /*
     * (non-Javadoc)
     * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet()
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        WallConfig config = new WallConfig();
        config.loadConfig("/druid/wall/mysql");
        // Overlay wall options from classpath:props.properties on top of the bundled MySQL rules.
        Properties pps = new Properties();
        Resource resource = resourceLoader.getResource("classpath:props.properties");
        try (InputStream is = resource.getInputStream()) {
            pps.load(is);
        }
        config.configFromProperties(pps);
        sqlWallProvider = new MySqlWallProvider(config);

        JdbcTemplate sqlExecutor = new JdbcTemplate();
        sqlExecutor.setDataSource(druidDataSource);
        sqlExecutor.setMaxRows(DEFAULT_MAX_ROWS);
        // Only a single shard is registered for now; to support multiple databases, register one
        // JdbcTemplate (same max-rows limit) per shard data source, keyed by the shard id.
        jdbcTemplates.put(DEFAULT_SHARD_ID, sqlExecutor);
    }

}
